﻿using DotNetCommon;
using DotNetCommon.Data;
using DotNetCommon.Extensions;
using DotNetCommon.Logger;
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace DBUtil
{
    /// <summary>
    /// Runs delegates while holding a named lock.
    /// </summary>
    /// <remarks>
    /// Every member takes the database that backs the lock, the lock name, the work to run, and the
    /// number of seconds to wait for the lock (default 180, i.e. 3 minutes).
    /// </remarks>
    public interface ILocker
    {
        /// <summary>Runs <paramref name="action"/> while holding the lock named <paramref name="lock_str"/>.</summary>
        void RunInLock(DBAccess db, string lock_str, Action action, int getLockTimeoutSecond = 180);

        /// <summary>Runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/> and returns its result.</summary>
        T RunInLock<T>(DBAccess db, string lock_str, Func<T> func, int getLockTimeoutSecond = 180);

        /// <summary>Asynchronously runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/>.</summary>
        Task RunInLockAsync(DBAccess db, string lock_str, Func<Task> func, int getLockTimeoutSecond = 180);

        /// <summary>Asynchronously runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/> and returns its result.</summary>
        Task<T> RunInLockAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond = 180);
    }

    /// <summary>
    /// Lock backed by a database table. Each lock string is first serialized inside this process with a
    /// <see cref="SemaphoreSlim"/>, then across processes by inserting a row into <see cref="DBLockTableName"/>.
    /// </summary>
    /// <remarks>
    /// <para>
    /// When this class creates the lock table, it declares <c>lock_str</c> as <c>unique</c>, so only one insert
    /// per lock string can succeed. While a lock is held, a background heartbeat refreshes <c>lock_time</c>
    /// every 5 seconds. Callers waiting for a lock delete rows whose <c>lock_time</c> is more than 20 seconds
    /// old, so a row left behind by a holder that stopped refreshing is eventually reclaimed.
    /// </para>
    /// <para>
    /// The lock is not reentrant. Re-entering the same lock string from inside the guarded delegate waits on
    /// the in-process semaphore until the timeout expires, and then throws.
    /// </para>
    /// <para>
    /// NOTE(review): the heartbeat and the guarded delegate may use the same <see cref="DBAccess"/> instance
    /// concurrently. Confirm that <see cref="DBAccess"/> tolerates that.
    /// </para>
    /// </remarks>
    public class CommonDBLocker : ILocker
    {
        // Set of "connection|table" keys whose lock table is known to exist (the bool values are unused).
        private static readonly ConcurrentDictionary<string, bool> _lockcaches = [];
        // One in-process gate per lock string. Entries are never removed.
        private static readonly ConcurrentDictionary<string, SemaphoreSlim> _lockSemaphores = [];
        // Private monitor for lock-table initialization. Locking on a public Type object would let
        // unrelated code contend for (or deadlock on) the same monitor.
        private static readonly object _initGate = new object();
        private readonly ILogger<CommonDBLocker> logger = LoggerFactory.CreateLogger<CommonDBLocker>();

        /// <summary>Name of the table that stores the lock rows.</summary>
        public string DBLockTableName { get; private set; }

        /// <summary>Creates a locker that stores its lock rows in <paramref name="dBLockTableName"/>.</summary>
        /// <param name="dBLockTableName">
        /// Table name. It is interpolated into SQL as-is, so it must be a trusted identifier.
        /// </param>
        public CommonDBLocker(string dBLockTableName)
        {
            this.DBLockTableName = dBLockTableName;
        }

        /// <summary>
        /// Creates <see cref="DBLockTableName"/> on <paramref name="db"/> if it does not exist yet.
        /// Once the table is known to exist, that fact is cached per connection string and table name.
        /// </summary>
        private void EnsureInitLock(DBAccess db)
        {
            // Key by table name as well. Two lockers using different tables on the same connection
            // must each get their own table created.
            var cacheKey = $"{db.DBConn}|{DBLockTableName}";
            if (_lockcaches.ContainsKey(cacheKey)) return;
            lock (_initGate)
            {
                if (_lockcaches.ContainsKey(cacheKey)) return;
                if (!db.IsTableExist(DBLockTableName))
                {
                    try
                    {
                        db.ExecuteSql($@"
create table {DBLockTableName}(
	lock_str varchar(200) unique not null,
	lock_time {db.DateTimeSqlSegment.DefaultDateTimeType} default({db.DateTimeSqlSegment.Current}),
    lock_createtime {db.DateTimeSqlSegment.DefaultDateTimeType} default({db.DateTimeSqlSegment.Current}),
	lock_user varchar(200) not null
)");
                    }
                    catch (Exception) when (db.IsTableExist(DBLockTableName))
                    {
                        // Another process created the table between our existence check and our create.
                    }
                }
                _lockcaches.TryAdd(cacheKey, true);
            }
        }

        /// <summary>
        /// Doubles single quotes so that <paramref name="value"/> can be embedded in a single-quoted SQL
        /// string literal.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this does not handle backslash escapes (e.g. MySQL's default mode). Switching these
        /// statements to parameterized SQL would be the complete fix.
        /// </remarks>
        private static string QuoteLiteral(string value) => value.Replace("'", "''");

        /// <summary>
        /// Tries to insert the lock row for <paramref name="lock_str"/>. After each failed attempt it pauses
        /// 500 ms, deletes every row (for any lock string) whose <c>lock_time</c> is more than 20 seconds old,
        /// and retries. It gives up once more than <paramref name="timeoutSecond"/> seconds have passed since
        /// <paramref name="starttime"/>.
        /// </summary>
        /// <param name="heartbeatToken">Cancels the heartbeat that is started once the row is inserted.</param>
        /// <returns>
        /// <c>Ok</c> once the row is inserted; <c>NotOk</c> on timeout, or when deleting stale rows fails.
        /// </returns>
        /// <remarks>Any insert failure, not only a unique-key violation, is treated as "the lock is held".</remarks>
        private async Task<Result> TryGetLock(DBAccess db, DateTime starttime, string lock_str, string lock_user, CancellationToken heartbeatToken, int timeoutSecond = 60 * 5)
        {
            EnsureInitLock(db);
            var lockStrLiteral = QuoteLiteral(lock_str);
            while (true)
            {
                try
                {
                    await db.ExecuteSqlAsync($"insert into {DBLockTableName}(lock_str,lock_user) values('{lockStrLiteral}','{lock_user}')");
                    StartLockMonitor(db, lock_str, lock_user, heartbeatToken);
                    return Result.Ok();
                }
                catch (Exception ex)
                {
                    var span = DateTime.Now - starttime;
                    if (span.TotalSeconds > timeoutSecond)
                    {
                        return Result.NotOk($"获取锁[{lock_str}]超时:{ex?.Message}]");
                    }
                    await Task.Delay(500);
                    // Delete rows older than 20 s. A live holder refreshes its row every 5 s, so a row
                    // that old belongs to a holder that stopped refreshing it.
                    try
                    {
                        await db.ExecuteSqlAsync($"delete from {DBLockTableName} where lock_time <{db.DateTimeSqlSegment.GetCurrentAddSecond(-20)}");
                    }
                    catch (Exception ex2)
                    {
                        return Result.NotOk($"获取锁[{lock_str}]失败:{ex2?.Message}]");
                    }
                }
            }
        }

        /// <summary>
        /// Starts a fire-and-forget heartbeat that refreshes <c>lock_time</c> of the lock row every 5 seconds.
        /// </summary>
        /// <remarks>
        /// The heartbeat stops in three cases: <paramref name="token"/> is cancelled (the owner is releasing),
        /// the update no longer matches exactly one row (the row is gone), or the update throws (logged).
        /// </remarks>
        private void StartLockMonitor(DBAccess db, string lock_str, string lock_user, CancellationToken token)
        {
            var lockStrLiteral = QuoteLiteral(lock_str);
            _ = Task.Run(async () =>
            {
                while (!token.IsCancellationRequested)
                {
                    try
                    {
                        await Task.Delay(5 * 1000, token);
                    }
                    catch (OperationCanceledException)
                    {
                        // The owner is releasing the lock.
                        break;
                    }
                    try
                    {
                        int count = await db.ExecuteSqlAsync($"update {DBLockTableName} set lock_time= {db.DateTimeSqlSegment.Current} where lock_str='{lockStrLiteral}' and lock_user='{lock_user}'");
                        if (count != 1) break;
                    }
                    catch (Exception ex)
                    {
                        logger.LogError($"刷新锁时间失败(lock_str={lock_str},lock_user={lock_user}):{ex.Message}");
                        break;
                    }
                }
            });
        }

        /// <summary>
        /// Deletes the lock row owned by <paramref name="lock_user"/>. Exceptions propagate to the caller.
        /// </summary>
        private void TryReleaseLock(DBAccess db, string lock_str, string lock_user)
        {
            EnsureInitLock(db);
            db.ExecuteSql($"delete from {DBLockTableName} where lock_str='{QuoteLiteral(lock_str)}' and lock_user='{lock_user}'");
        }

        /// <summary>
        /// Runs <paramref name="action"/> while holding the lock named <paramref name="lock_str"/>, blocking the
        /// calling thread until it completes.
        /// </summary>
        /// <param name="db">Database that stores the lock row.</param>
        /// <param name="lock_str">Lock name. Callers using the same name are mutually exclusive.</param>
        /// <param name="action">Work to run while the lock is held.</param>
        /// <param name="getLockTimeoutSecond">
        /// Approximate maximum number of seconds to wait for the lock, measured from the call and covering both
        /// the in-process wait and the database wait.
        /// </param>
        /// <exception cref="AggregateException">
        /// Thrown after argument validation for any failure: lock timeout, database error, or an exception
        /// thrown by <paramref name="action"/>.
        /// </exception>
        /// <remarks>
        /// NOTE(review): this blocks on the async implementation with <see cref="Task.Wait()"/>. On a thread with a
        /// single-threaded <see cref="SynchronizationContext"/> (UI thread, classic ASP.NET) it may deadlock.
        /// </remarks>
        public void RunInLock(DBAccess db, string lock_str, Action action, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(action, nameof(action));
            var task = RunInLockInternalAsync(db, lock_str, () =>
            {
                action();
                return Task.FromResult(true);
            }, getLockTimeoutSecond);
            task.Wait();
        }

        /// <summary>
        /// Runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/> and returns its
        /// result, blocking the calling thread until it completes.
        /// </summary>
        /// <param name="db">Database that stores the lock row.</param>
        /// <param name="lock_str">Lock name. Callers using the same name are mutually exclusive.</param>
        /// <param name="func">Work to run while the lock is held.</param>
        /// <param name="getLockTimeoutSecond">Approximate maximum number of seconds to wait for the lock.</param>
        /// <exception cref="AggregateException">
        /// Thrown after argument validation for any failure: lock timeout, database error, or an exception
        /// thrown by <paramref name="func"/>.
        /// </exception>
        /// <remarks>
        /// NOTE(review): this blocks on the async implementation via <see cref="Task{TResult}.Result"/>. It may
        /// deadlock on a thread with a single-threaded <see cref="SynchronizationContext"/>.
        /// </remarks>
        public T RunInLock<T>(DBAccess db, string lock_str, Func<T> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            var task = RunInLockInternalAsync(db, lock_str, () =>
            {
                var res = func();
                return Task.FromResult(res);
            }, getLockTimeoutSecond);
            return task.Result;
        }

        /// <summary>
        /// Asynchronously runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/>.
        /// </summary>
        /// <param name="db">Database that stores the lock row.</param>
        /// <param name="lock_str">Lock name. Callers using the same name are mutually exclusive.</param>
        /// <param name="func">Work to run while the lock is held.</param>
        /// <param name="getLockTimeoutSecond">Approximate maximum number of seconds to wait for the lock.</param>
        public async Task RunInLockAsync(DBAccess db, string lock_str, Func<Task> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            await RunInLockInternalAsync(db, lock_str, async () =>
            {
                await func();
                return true;
            }, getLockTimeoutSecond);
        }

        /// <summary>
        /// Asynchronously runs <paramref name="func"/> while holding the lock named <paramref name="lock_str"/>
        /// and returns its result.
        /// </summary>
        /// <param name="db">Database that stores the lock row.</param>
        /// <param name="lock_str">Lock name. Callers using the same name are mutually exclusive.</param>
        /// <param name="func">Work to run while the lock is held.</param>
        /// <param name="getLockTimeoutSecond">Approximate maximum number of seconds to wait for the lock.</param>
        /// <remarks>
        /// Unlike the other overloads, this one wraps the whole operation in <c>db.RunInNoSessionAsync</c>.
        /// NOTE(review): confirm whether that asymmetry is intended.
        /// </remarks>
        public async Task<T> RunInLockAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond = 60 * 3)
        {
            Ensure.NotNull(lock_str, nameof(lock_str));
            Ensure.NotNull(func, nameof(func));
            return await db.RunInNoSessionAsync(async () => await RunInLockInternalAsync(db, lock_str, func, getLockTimeoutSecond));
        }

        /// <summary>
        /// Core lock sequence:
        /// <list type="number">
        /// <item>Wait on the in-process semaphore for <paramref name="lock_str"/>.</item>
        /// <item>Acquire the database row.</item>
        /// <item>Run <paramref name="func"/>.</item>
        /// <item>Stop the heartbeat, delete the row, and release the semaphore.</item>
        /// </list>
        /// </summary>
        /// <exception cref="Exception">Thrown when either stage of lock acquisition times out or fails.</exception>
        private async Task<T> RunInLockInternalAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond)
        {
            // Stage 1: serialize callers inside this process.
            var now = DateTime.Now;
            var asyncLock = _lockSemaphores.GetOrAdd(lock_str, _ => new SemaphoreSlim(1, 1));
            var acquired = await asyncLock.WaitAsync(getLockTimeoutSecond * 1000);
            if (!acquired) throw new Exception($"获取锁[{lock_str}]超时,单机内超时.");
            try
            {
                // Stage 2: serialize across processes through the database lock row.
                // The same start time is used, so both stages share one timeout budget.
                var lock_user = $"{now.ToCommonStampString()} {Guid.NewGuid().ToString("N")}";
                using var heartbeat = new CancellationTokenSource();
                var res = await TryGetLock(db, now, lock_str, lock_user, heartbeat.Token, getLockTimeoutSecond);
                if (!res.Success)
                {
                    throw new Exception(res.Message);
                }
                try
                {
                    return await func();
                }
                finally
                {
                    // Stop the heartbeat BEFORE deleting the row. If the delete fails, the row must stop being
                    // refreshed so that it goes stale after 20 s; otherwise the lock would be held forever.
                    heartbeat.Cancel();
                    TryReleaseLock(db, lock_str, lock_user);
                }
            }
            finally
            {
                // Only reached after a successful wait, so this cannot over-release the semaphore.
                asyncLock.Release();
            }
        }
    }
}
